Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix for gh-1468 in arithmetic reduction when type promotion is needed #1470

Merged
merged 4 commits into from
Nov 8, 2023

Conversation

oleksandr-pavlyk
Copy link
Collaborator

@oleksandr-pavlyk oleksandr-pavlyk commented Nov 7, 2023

Closes gh-1468

This PR changes _reduce_over_axis when an implementation to reduce using requested dtype for given input type does not exist. The change casts input array into a temporary array of requested dtype and calls reduction on the temporary (silent assumption that implementation to reduce input array using the same input dtype is available).

  • Have you provided a meaningful PR description?
  • Have you added a test, reproducer or referred to an issue with a reproducer?
  • Have you tested your changes locally for CPU and GPU devices?
  • Have you made sure that new changes do not introduce compiler warnings?
  • Have you checked performance impact of proposed changes?
  • If this PR is a work in progress, are you opening the PR as a draft?

Removed mention of dtype kwarg in usage line
Function _reduce_over_axis promotes input array to requested
result data type and carries out reduction computation in that
data type. This is done in dtype if implementation supports it.

If implementation does not support the requested dtype, we reduce
in the default_dtype, and cast to the request dtype afterwards.
Copy link

github-actions bot commented Nov 7, 2023

Copy link

github-actions bot commented Nov 7, 2023

Array API standard conformance tests for dpctl=0.15.1dev1=py310ha25a700_4 ran successfully.
Passed: 935
Failed: 65
Skipped: 119

Copy link

github-actions bot commented Nov 7, 2023

Array API standard conformance tests for dpctl=0.15.1dev1=py310ha25a700_3 ran successfully.
Passed: 935
Failed: 65
Skipped: 119

@coveralls
Copy link
Collaborator

Coverage Status

coverage: 85.796% (+0.01%) from 85.785%
when pulling ff9b5eb on fix-gh-1468-reduction
into 9018745 on master.

Copy link

github-actions bot commented Nov 7, 2023

Array API standard conformance tests for dpctl=0.15.1dev1=py310ha25a700_4 ran successfully.
Passed: 935
Failed: 65
Skipped: 119

@@ -118,7 +118,7 @@ def _reduction_over_axis(
dpt.full(
res_shape,
_identity,
dtype=_default_reduction_type_fn(inp_dt, q),
dtype=dtype,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With this change, the call to astype is no longer necessary.

It also means that when logsumexp or reduce_hypot reduce over an empty axis or array (e.g., dpt.logsumexp(dpt.ones((1, 0, 1), dtype="i4"), axis=1, dtype="i4")) you get OverflowError instead of silently casting the identity to the output type.

For now, the astype can be removed. I've experimented with removing this branch in #1465 too.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, and it should not be dtype, it should be res_dt.

Copy link

github-actions bot commented Nov 8, 2023

Array API standard conformance tests for dpctl=0.15.1dev1=py310ha25a700_5 ran successfully.
Passed: 935
Failed: 65
Skipped: 119

Copy link
Collaborator

@ndgrigorian ndgrigorian left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you @oleksandr-pavlyk !

@oleksandr-pavlyk oleksandr-pavlyk merged commit dbab3fe into master Nov 8, 2023
26 checks passed
@oleksandr-pavlyk oleksandr-pavlyk deleted the fix-gh-1468-reduction branch November 8, 2023 03:16
@vtavana vtavana mentioned this pull request Nov 23, 2023
6 tasks
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

dpt.sum is not doing computation in requested dtype on Windows
3 participants